/**
 * The input to OpenAI API, directly adopted from openai-node with small tweaks:
 * https://github.com/openai/openai-node/blob/master/src/resources/completions.ts
 *
 * Copyright 2024 OpenAI
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *      http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import { MLCEngineInterface } from "../types";
import {
  InvalidStreamOptionsError,
  SeedTypeError,
  StreamingCountError,
  UnsupportedFieldsError,
} from "../error";
import {
  ChatCompletion,
  ChatCompletionStreamOptions,
  CompletionUsage,
  ChatCompletionFinishReason,
} from "./chat_completion";

/**
 * Entry point for the OpenAI-style text completion API.
 *
 * The class holds no state of its own besides the engine it wraps; every
 * request is forwarded unchanged to `engine.completion()`.
 */
export class Completions {
  /**
   * @param engine Engine whose `completion()` method serves all requests.
   */
  constructor(private engine: MLCEngineInterface) {}

  /** Non-streaming request: resolves to a single `Completion`. */
  create(request: CompletionCreateParamsNonStreaming): Promise<Completion>;
  /** Streaming request: resolves to an async iterable of `Completion` chunks. */
  create(
    request: CompletionCreateParamsStreaming,
  ): Promise<AsyncIterable<Completion>>;
  /** Streaming mode not known statically: resolves to either of the two shapes. */
  create(
    request: CompletionCreateParamsBase,
  ): Promise<AsyncIterable<Completion> | Completion>;
  create(
    request: CompletionCreateParams,
  ): Promise<AsyncIterable<Completion> | Completion> {
    // Pure delegation: whatever the engine returns is handed back as-is.
    const { engine } = this;
    return engine.completion(request);
  }
}

//////////////////////////////// 1. CREATE PARAMS ////////////////////////////////
/**
 * OpenAI completion request protocol.
 *
 * API reference: https://platform.openai.com/docs/api-reference/completions/create
 * Followed: https://github.com/openai/openai-node/blob/master/src/resources/completions.ts
 *
 * @note Unlike OpenAI's API, `model` is optional here (see the `model` field below). Call
 * `CreateMLCEngine(model)` or `engine.reload(model)` explicitly before calling this API.
 * @note Requests are validated by `postInitAndCheckFields()`; the per-field notes below
 * mention the checks it performs.
 */
export interface CompletionCreateParamsBase {
  /**
   * The prompt to generate completions for, as a single string.
   */
  prompt: string;

  /**
   * Echo back the prompt in addition to the completion.
   */
  echo?: boolean | null;

  /**
   * Number between -2.0 and 2.0. Positive values penalize new tokens based on their
   * existing frequency in the text so far, decreasing the model's likelihood to
   * repeat the same line verbatim.
   *
   * [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation/parameter-details)
   */
  frequency_penalty?: number | null;

  /**
   * Modify the likelihood of specified tokens appearing in the completion.
   *
   * Accepts a JSON object that maps tokens (specified by their token ID, which varies per model)
   * to an associated bias value from -100 to 100. Typically, you can see `tokenizer.json` of the
   * model to see which token ID maps to what string. Mathematically, the bias is added to the
   * logits generated by the model prior to sampling. The exact effect will vary per model, but
   * values between -1 and 1 should decrease or increase likelihood of selection; values like -100
   * or 100 should result in a ban or exclusive selection of the relevant token.
   *
   * As an example, you can pass `{"16230": -100}` to prevent the `Hello` token from being
   * generated in Mistral-7B-Instruct-v0.2, according to the mapping in
   * https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/raw/main/tokenizer.json.
   *
   * @note For stateful and customizable / flexible logit processing, see `webllm.LogitProcessor`.
   * @note If used in combination with `webllm.LogitProcessor`, `logit_bias` is applied after
   * `LogitProcessor.processLogits()` is called.
   */
  logit_bias?: Record<string, number> | null;

  /**
   * Whether to return log probabilities of the output tokens or not.
   *
   * If true, log probability information for the output tokens is reported through the
   * `logprobs` field of each returned `CompletionChoice`.
   */
  logprobs?: boolean | null;

  /**
   * An integer between 0 and 5 specifying the number of most likely tokens to return
   * at each token position, each with an associated log probability. `logprobs` must
   * be set to `true` if this parameter is used.
   */
  top_logprobs?: number | null;

  /**
   * The maximum number of tokens that can be generated in the completion.
   *
   * The total length of input tokens and generated tokens is limited by the model's
   * context length.
   */
  max_tokens?: number | null;

  /**
   * How many completions to generate for each prompt.
   *
   * @note Cannot be greater than 1 when `stream` is true; `postInitAndCheckFields()` throws
   * `StreamingCountError` in that case.
   */
  n?: number | null;

  /**
   * Number between -2.0 and 2.0. Positive values penalize new tokens based on
   * whether they appear in the text so far, increasing the model's likelihood to
   * talk about new topics.
   *
   * [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation/parameter-details)
   */
  presence_penalty?: number | null;

  /**
   * If specified, our system will make a best effort to sample deterministically,
   * such that repeated requests with the same `seed` and parameters should return
   * the same result.
   *
   * @note Must be an integer; `postInitAndCheckFields()` throws `SeedTypeError` otherwise.
   * @note Seeding is done on a request-level rather than choice-level. That is, if `n > 1`, you
   * would still get different content for each `Choice`. But if two requests with `n = 2` are
   * processed with the same seed, the two results should be the same (two choices are different).
   */
  seed?: number | null;

  /**
   * Up to 4 sequences where the API will stop generating further tokens. The
   * returned text will not contain the stop sequence.
   */
  stop?: string | null | Array<string>;

  /**
   * If set, partial deltas will be sent. It will be terminated by an empty chunk.
   */
  stream?: boolean | null;

  /**
   * Options for streaming response. Only set this when you set `stream: true`;
   * `postInitAndCheckFields()` throws `InvalidStreamOptionsError` if it is set on a
   * non-streaming request.
   */
  stream_options?: ChatCompletionStreamOptions | null;

  /**
   * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
   * make the output more random, while lower values like 0.2 will make it more
   * focused and deterministic.
   *
   * We generally recommend altering this or `top_p` but not both.
   */
  temperature?: number | null;

  /**
   * An alternative to sampling with temperature, called nucleus sampling, where the
   * model considers the results of the tokens with top_p probability mass. So 0.1
   * means only the tokens comprising the top 10% probability mass are considered.
   *
   * We generally recommend altering this or `temperature` but not both.
   */
  top_p?: number | null;

  /**
   * If true, will ignore stop string and stop token and generate until max_tokens hit.
   * If unset, will treat as false.
   */
  ignore_eos?: boolean;

  /**
   * ID of the model to use. This equals to `ModelRecord.model_id`, which needs to either be in
   * `webllm.prebuiltAppConfig` or in `engineConfig.appConfig`.
   *
   * @note Call `CreateMLCEngine(model)` or `engine.reload(model)` ahead of time.
   * @note If only one model is loaded in the engine, this field is optional. If multiple models
   *   are loaded, this is required.
   */
  model?: string | null;

  //////////////// BELOW FIELDS NOT SUPPORTED YET ////////////////
  // Each of the fields below is listed in `CompletionCreateParamsUnsupportedFields`.
  // `postInitAndCheckFields()` throws `UnsupportedFieldsError` as soon as one of these keys is
  // present on the request object, whatever its value.

  /**
   * The suffix that comes after a completion of inserted text.
   *
   * @note This field is not supported.
   */
  suffix?: string | null;

  /**
   * A unique identifier representing your end-user, which can help OpenAI to monitor
   * and detect abuse.
   *
   * @note This field is not supported.
   */
  user?: string;

  /**
   * Generates `best_of` completions server-side and returns the "best" (the one with
   * the highest log probability per token). Results cannot be streamed.
   *
   * When used with `n`, `best_of` controls the number of candidate completions and
   * `n` specifies how many to return – `best_of` must be greater than `n`.
   *
   * @note This field is not supported.
   */
  best_of?: number | null;
}

/**
 * Any completion request: either the non-streaming or the streaming variant.
 * The two variants differ only in the type of their `stream` field.
 */
export type CompletionCreateParams =
  | CompletionCreateParamsNonStreaming
  | CompletionCreateParamsStreaming;

/**
 * Completion request that does not stream. `Completions.create()` resolves to a single
 * `Completion` for this variant.
 */
export interface CompletionCreateParamsNonStreaming
  extends CompletionCreateParamsBase {
  /**
   * Must be omitted, `false`, or `null` for a non-streaming request, so no partial
   * deltas are sent.
   */
  stream?: false | null;
}

/**
 * Completion request that streams. `Completions.create()` resolves to an
 * `AsyncIterable<Completion>` for this variant.
 */
export interface CompletionCreateParamsStreaming
  extends CompletionCreateParamsBase {
  /**
   * Always `true` for a streaming request: partial deltas will be sent. It will be
   * terminated by an empty chunk.
   */
  stream: true;
}

//////////////////////////////// 2. RESPONSE ////////////////////////////////
/**
 * Represents a completion response returned by model, based on the provided input.
 *
 * The same shape is used for a whole non-streaming response and for each element of the
 * async iterable produced by a streaming request (see the overloads of `Completions.create()`).
 */
export interface Completion {
  /**
   * A unique identifier for the completion.
   */
  id: string;

  /**
   * The list of completion choices the model generated for the input prompt.
   */
  choices: Array<CompletionChoice>;

  /**
   * The Unix timestamp (in seconds) of when the completion was created.
   */
  created: number;

  /**
   * The model used for completion.
   */
  model: string;

  /**
   * The object type, which is always "text_completion".
   */
  object: "text_completion";

  /**
   * This fingerprint represents the backend configuration that the model runs with.
   *
   * Can be used in conjunction with the `seed` request parameter to understand when
   * backend changes have been made that might impact determinism.
   *
   * @note Not supported yet.
   */
  system_fingerprint?: string;

  /**
   * Usage statistics for the completion request. Optional: it may be absent.
   */
  usage?: CompletionUsage;
}

/**
 * One entry of `Completion.choices`.
 */
export interface CompletionChoice {
  /**
   * The reason the model stopped generating tokens. This will be `stop` if the model
   * hit a natural stop point or a provided stop sequence, or `length` if the maximum
   * number of tokens specified in the request was reached.
   *
   * The type is shared with chat completions; see `ChatCompletionFinishReason` for the
   * full set of possible values.
   * NOTE(review): `null` presumably means generation has not finished yet (e.g. an
   * intermediate streaming chunk) — confirm against the engine implementation.
   */
  finish_reason: ChatCompletionFinishReason | null;

  /**
   * Numeric index identifying this choice.
   * NOTE(review): presumably its zero-based position among the `n` requested choices — confirm.
   */
  index: number;

  /**
   * A list of message content tokens with log probability information.
   * @note Different from openai-node, we reuse ChatCompletion's Logprobs.
   */
  logprobs?: ChatCompletion.Choice.Logprobs | null;

  /**
   * The text of this choice.
   * NOTE(review): whether the prompt is included when `echo` is set is decided by the
   * engine, not visible here — confirm.
   */
  text: string;
}

//////////////////////////////// 3. POST INIT ////////////////////////////////

/**
 * Names of the `CompletionCreateParams` fields that are declared for API parity
 * but not supported. `postInitAndCheckFields()` rejects any request that
 * carries one of these keys.
 */
export const CompletionCreateParamsUnsupportedFields: string[] = [
  "suffix",
  "user",
  "best_of",
];

/**
 * Validate a completion request before it is executed.
 *
 * The checks run in the order below and the first one that fails throws:
 *  1. no key listed in `CompletionCreateParamsUnsupportedFields` is present;
 *  2. a streaming request does not ask for more than one choice (`n > 1`);
 *  3. `seed`, when provided, is an integer;
 *  4. `stream_options`, when provided, accompanies a streaming request.
 *
 * The current checks only read `request`; none of them modifies it.
 *
 * @param request User's input request.
 * @param currentModelId The current model loaded that will perform this request.
 *   Not used by any of the current checks.
 * @throws UnsupportedFieldsError If an unsupported key is present on the request.
 * @throws StreamingCountError If `stream` is truthy and `n` is greater than 1.
 * @throws SeedTypeError If `seed` is neither `undefined`/`null` nor an integer.
 * @throws InvalidStreamOptionsError If `stream_options` is set without streaming.
 */
export function postInitAndCheckFields(
  request: CompletionCreateParams,
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  currentModelId: string,
): void {
  // 1. Collect every unsupported key carried by the request. Presence is tested
  //    with `in`, so a key explicitly set to `undefined` or `null` still counts.
  const rejectedFields = CompletionCreateParamsUnsupportedFields.filter(
    (name) => name in request,
  );
  if (rejectedFields.length > 0) {
    throw new UnsupportedFieldsError(rejectedFields, "CompletionCreateParams");
  }

  // 2. Streaming handles a single sequence only, so `n > 1` is refused
  const requestedChoices = request.n;
  if (request.stream && requestedChoices && requestedChoices > 1) {
    throw new StreamingCountError();
  }

  // 3. A provided seed has to be an integer
  const seed = request.seed;
  if (seed !== undefined && seed !== null && !Number.isInteger(seed)) {
    throw new SeedTypeError(seed);
  }

  // 4. `stream_options` is only meaningful for a streaming request
  const streamOptions = request.stream_options;
  if (streamOptions !== undefined && streamOptions !== null && !request.stream) {
    throw new InvalidStreamOptionsError();
  }
}
